{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'cv2'",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-bb5a46e81331>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'CUDA_VISIBLE_DEVICES'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'0'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
     ]
    }
   ],
   "source": [
    "import os,sys,glob,shutil,json\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "import cv2\n",
    "\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from tqdm import tqdm,tqdm_notebook\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch.autograd import Variable\n",
    "\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- step1 定义好只读图像的Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SVHDataset(Dataset):\n",
    "    \"\"\"Image dataset that pairs each image file with a fixed-length label.\n",
    "\n",
    "    Args:\n",
    "        img_path: indexable collection of image file paths.\n",
    "        img_label: indexable collection of integer label sequences, aligned\n",
    "            with ``img_path`` by index.\n",
    "        transforms: optional callable applied to the loaded RGB image.\n",
    "    \"\"\"\n",
    "\n",
    "    # Every label is padded / truncated to this many entries.\n",
    "    MAX_CHARS = 5\n",
    "    # Filler value for unused label slots.\n",
    "    # NOTE(review): presumably the 'no character' class index -- confirm.\n",
    "    PAD_CLASS = 10\n",
    "\n",
    "    def __init__(self,img_path,img_label,transforms=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __getitem__(self,index):\n",
    "        \"\"\"Return ``(image, label)`` for the sample at ``index``.\n",
    "\n",
    "        ``label`` is an int64 tensor of length ``MAX_CHARS``: shorter label\n",
    "        sequences are right-padded with ``PAD_CLASS``, longer ones are cut.\n",
    "        \"\"\"\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "\n",
    "        if self.transforms is not None:\n",
    "            img = self.transforms(img)\n",
    "\n",
    "        # The label comes from img_label, not img_path. np.int64 is used\n",
    "        # because the np.int alias was removed in NumPy 1.24.\n",
    "        lbl = np.array(self.img_label[index],dtype=np.int64)\n",
    "        # The longest supported label has MAX_CHARS entries; pad the rest.\n",
    "        lbl = list(lbl) + (self.MAX_CHARS-len(lbl))*[self.PAD_CLASS]\n",
    "        return img,torch.from_numpy(np.array(lbl[:self.MAX_CHARS],dtype=np.int64))\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of image paths in the dataset.\"\"\"\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "- step2: define the `Dataset` of `train data and valid data`"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sort the paths so the sample order is the same on every run.\n",
    "train_path = sorted(glob.glob('input/train/*.png'))\n",
    "\n",
    "# Use a context manager so the file handle is closed after parsing.\n",
    "with open('input/train.json') as f:\n",
    "    train_json = json.load(f)\n",
    "\n",
    "# NOTE(review): assumes train.json maps each image file name to a dict\n",
    "# holding a 'label' list -- confirm against the data file.\n",
    "# Labels are looked up by file name so they stay aligned with train_path.\n",
    "train_label = [train_json[os.path.basename(p)]['label'] for p in train_path]\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}